{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network Architecture Search with Eve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import eve\n",
    "import eve.app\n",
    "import eve.app.model\n",
    "import eve.app.trainer\n",
    "import eve.core\n",
    "import eve.app.space as space\n",
    "\n",
    "# Default to GPU 1, but respect a device selection already exported by the caller.\n",
    "# Set before any CUDA call so that the restriction takes effect.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build a basic network for trainer\n",
    "class mnist(eve.core.Eve):\n",
    "    \"\"\"Small quantization-aware CNN sized for 28x28 single-channel images.\n",
    "\n",
    "    Three stride-2 Conv2d + BatchNorm2d stages (1 -> 4 -> 8 -> 16 channels)\n",
    "    feed two linear layers (16 * 4 * 4 -> 16 -> 10). Each of the first four\n",
    "    layers is followed by an ``IFNode`` and a ``Quantizer``; ``linear2``\n",
    "    emits the outputs directly.\n",
    "\n",
    "    Args:\n",
    "        neuron_wise: forwarded unchanged to every ``Quantizer``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, neuron_wise: bool = False):\n",
    "        super().__init__()\n",
    "\n",
    "        # Global statistics are registered before any State is created.\n",
    "        for statistic in (\"l1_norm\", \"kl_div\"):\n",
    "            eve.core.State.register_global_statistic(statistic)\n",
    "\n",
    "        # Stage 1: 1 -> 4 channels, stride-2 downsampling.\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(num_features=4),\n",
    "        )\n",
    "        # IFNode with binary=False stands in for a ReLU.\n",
    "        self.node1 = eve.core.IFNode(eve.core.State(self.conv1), binary=False)\n",
    "        self.quan1 = eve.core.Quantizer(\n",
    "            eve.core.State(self.conv1),\n",
    "            upgrade_bits=True,\n",
    "            neuron_wise=neuron_wise,\n",
    "        )\n",
    "\n",
    "        # Stage 2: 4 -> 8 channels.\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(num_features=8),\n",
    "        )\n",
    "        self.node2 = eve.core.IFNode(eve.core.State(self.conv2), binary=False)\n",
    "        self.quan2 = eve.core.Quantizer(\n",
    "            eve.core.State(self.conv2),\n",
    "            upgrade_bits=True,\n",
    "            neuron_wise=neuron_wise,\n",
    "        )\n",
    "\n",
    "        # Stage 3: 8 -> 16 channels.\n",
    "        self.conv3 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(num_features=16),\n",
    "        )\n",
    "        self.node3 = eve.core.IFNode(eve.core.State(self.conv3), binary=False)\n",
    "        self.quan3 = eve.core.Quantizer(\n",
    "            eve.core.State(self.conv3),\n",
    "            upgrade_bits=True,\n",
    "            neuron_wise=neuron_wise,\n",
    "        )\n",
    "\n",
    "        # Head: 16 * 4 * 4 flattened features -> 16 hidden units -> 10 outputs.\n",
    "        self.linear1 = nn.Linear(in_features=16 * 4 * 4, out_features=16)\n",
    "        # NOTE(review): unlike node1-3, node4 does not pass binary=False and so\n",
    "        # uses IFNode's default -- confirm this is intentional.\n",
    "        self.node4 = eve.core.IFNode(eve.core.State(self.linear1))\n",
    "        self.quan4 = eve.core.Quantizer(\n",
    "            eve.core.State(self.linear1),\n",
    "            upgrade_bits=True,\n",
    "            neuron_wise=neuron_wise,\n",
    "        )\n",
    "\n",
    "        self.linear2 = nn.Linear(in_features=16, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the network on a batch and return the output of ``linear2``.\n",
    "\n",
    "        Args:\n",
    "            x: input batch passed to ``conv1`` (one input channel).\n",
    "\n",
    "        Returns:\n",
    "            The ``linear2`` output with the singleton dim 1 squeezed away.\n",
    "        \"\"\"\n",
    "        # Each stage is layer -> IFNode -> Quantizer.\n",
    "        out = self.quan1(self.node1(self.conv1(x)))\n",
    "        out = self.quan2(self.node2(self.conv2(out)))\n",
    "        out = self.quan3(self.node3(self.conv3(out)))\n",
    "\n",
    "        # Flatten everything after the batch dim, then insert a singleton\n",
    "        # dim at position 1; it is removed again before returning.\n",
    "        flat = th.flatten(out, start_dim=1).unsqueeze(1)\n",
    "\n",
    "        hidden = self.quan4(self.node4(self.linear1(flat)))\n",
    "        logits = self.linear2(hidden)\n",
    "\n",
    "        return logits.squeeze(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MnistClassifier(eve.app.model.Classifier):\n",
    "    \"\"\"Classifier that supplies MNIST train / valid / test dataloaders.\"\"\"\n",
    "\n",
    "    def prepare_data(self,\n",
    "                     data_root: str,\n",
    "                     batch_size: int = 128,\n",
    "                     num_workers: int = 4,\n",
    "                     valid_size: int = 5000,\n",
    "                     seed=None):\n",
    "        \"\"\"Download MNIST if needed and build the three dataloaders.\n",
    "\n",
    "        Sets ``train_dataset``, ``valid_dataset``, ``test_dataset`` and the\n",
    "        matching ``*_dataloader`` attributes on ``self``.\n",
    "\n",
    "        Args:\n",
    "            data_root: directory in which torchvision stores / finds MNIST.\n",
    "            batch_size: batch size shared by all three loaders.\n",
    "            num_workers: worker processes used by each loader.\n",
    "            valid_size: number of training images held out for validation;\n",
    "                the remaining images form the training split. With the\n",
    "                60000-image MNIST training set the default gives the\n",
    "                original 55000 / 5000 split.\n",
    "            seed: optional integer seeding the train / valid split so that\n",
    "                it is reproducible. ``None`` keeps the unseeded split.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``valid_size`` is negative or leaves no training\n",
    "                images.\n",
    "        \"\"\"\n",
    "        from torch.utils.data import DataLoader, random_split\n",
    "        from torchvision import transforms\n",
    "        from torchvision.datasets import MNIST\n",
    "\n",
    "        def load_split(train: bool):\n",
    "            # Both splits share root, download flag and transform.\n",
    "            return MNIST(root=data_root,\n",
    "                         train=train,\n",
    "                         download=True,\n",
    "                         transform=transforms.ToTensor())\n",
    "\n",
    "        def make_loader(dataset, shuffle: bool):\n",
    "            return DataLoader(dataset,\n",
    "                              batch_size=batch_size,\n",
    "                              shuffle=shuffle,\n",
    "                              num_workers=num_workers)\n",
    "\n",
    "        train_dataset = load_split(train=True)\n",
    "        test_dataset = load_split(train=False)\n",
    "\n",
    "        if not 0 <= valid_size < len(train_dataset):\n",
    "            raise ValueError(\n",
    "                f\"valid_size must be in [0, {len(train_dataset)}), got {valid_size}\"\n",
    "            )\n",
    "\n",
    "        # Derive the training size instead of hard-coding it, and seed the\n",
    "        # split only when the caller asked for reproducibility.\n",
    "        split_kwargs = {}\n",
    "        if seed is not None:\n",
    "            split_kwargs[\"generator\"] = th.Generator().manual_seed(seed)\n",
    "        self.train_dataset, self.valid_dataset = random_split(\n",
    "            train_dataset,\n",
    "            [len(train_dataset) - valid_size, valid_size],\n",
    "            **split_kwargs)\n",
    "        self.test_dataset = test_dataset\n",
    "\n",
    "        # Only the training loader shuffles.\n",
    "        self.train_dataloader = make_loader(self.train_dataset, shuffle=True)\n",
    "        self.test_dataloader = make_loader(self.test_dataset, shuffle=False)\n",
    "        self.valid_dataloader = make_loader(self.valid_dataset, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MnistTrainer(eve.app.trainer.BaseTrainer):\n",
    "    \"\"\"Trainer environment whose reward is accuracy minus a penalty on the eve parameters.\"\"\"\n",
    "\n",
    "    def reset(self) -> np.ndarray:\n",
    "        \"\"\"Close out the finished episode and return the initial observation.\n",
    "\n",
    "        Caches the model if the episode reached a new best accumulated\n",
    "        reward, runs a validation and reloads the cached model every\n",
    "        ``eval_steps`` resets, clears per-episode state, and runs one\n",
    "        ``fit_step`` before fetching the observation.\n",
    "\n",
    "        Returns:\n",
    "            obs: np.ndarray, the initial observation of trainer.\n",
    "        \"\"\"\n",
    "        # Save the model first if this episode achieved a higher reward.\n",
    "        # This has to precede the periodic reload below: reloading first would\n",
    "        # discard the better weights and then re-cache the old model under\n",
    "        # the new best reward.\n",
    "        # NOTE(review): assumes load_from_RAM() restores what cache_to_RAM()\n",
    "        # stored -- confirm against BaseTrainer.\n",
    "        if self.accumulate_reward > self.best_reward:\n",
    "            self.cache_to_RAM()\n",
    "            self.best_reward = self.accumulate_reward\n",
    "\n",
    "        # Do a fast validation every eval_steps resets.\n",
    "        self.steps += 1\n",
    "        if self.steps % self.eval_steps == 0:\n",
    "            self.steps = 0\n",
    "            finetune_acc = self.valid()[\"acc\"]\n",
    "            # Track the best validation accuracy seen so far.\n",
    "            if finetune_acc > self.finetune_acc:\n",
    "                self.finetune_acc = finetune_acc\n",
    "            # Reset the model to explore more possibilities.\n",
    "            self.load_from_RAM()\n",
    "\n",
    "        # Clear the accumulated reward for the next episode.\n",
    "        self.accumulate_reward = 0.0\n",
    "\n",
    "        # WARN: self.last_eve and self.obs_generator must be reset to None.\n",
    "        # Sometimes an episode is interrupted without the generator being reset.\n",
    "        self.last_eve = None\n",
    "        self.obs_generator = None\n",
    "        self.upgrader.zero_obs()\n",
    "        self.fit_step()\n",
    "        return self.fetch_obs()\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Test the cached best model, print a summary and save a checkpoint.\n",
    "\n",
    "        The checkpoint is written only when ``self.tensorboard_log`` is not\n",
    "        ``None``; its path is ``self.kwargs[\"save_path\"]`` if present, else\n",
    "        ``model.ckpt`` inside the tensorboard log directory.\n",
    "        \"\"\"\n",
    "        # Load the best model first.\n",
    "        self.load_from_RAM()\n",
    "\n",
    "        finetune_acc = self.test()[\"acc\"]\n",
    "\n",
    "        # Accumulate plain Python numbers so the summary also works when\n",
    "        # eve_parameters() yields nothing (the int 0 has no .item()).\n",
    "        # NOTE(review): 8 looks like the maximum bit width per parameter -- confirm.\n",
    "        bits = 0.0\n",
    "        bound = 0\n",
    "        for v in self.upgrader.eve_parameters():\n",
    "            bits += th.floor(v.mean() * 8).item()\n",
    "            bound += 8\n",
    "\n",
    "        print(\n",
    "            f\"baseline: {self.baseline_acc}, ours: {finetune_acc}, bits: {bits} / {bound}\"\n",
    "        )\n",
    "\n",
    "        if self.tensorboard_log is not None:\n",
    "            save_path = self.kwargs.get(\n",
    "                \"save_path\", os.path.join(self.tensorboard_log, \"model.ckpt\"))\n",
    "            self.save_checkpoint(path=save_path)\n",
    "            print(f\"save trained model to {save_path}\")\n",
    "\n",
    "    def reward(self) -> float:\n",
    "        \"\"\"Run one fit step and score it.\n",
    "\n",
    "        Returns:\n",
    "            The step's ``acc`` minus 0.4 times the mean of ``self.last_eve``.\n",
    "        \"\"\"\n",
    "        self.upgrader.zero_obs()\n",
    "\n",
    "        info = self.fit_step()\n",
    "\n",
    "        # Accuracy is rewarded; a larger mean of last_eve is penalised.\n",
    "        return info[\"acc\"] - self.last_eve.mean().item() * 0.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('buffer_size', 2000),\n",
      "             ('gamma', 0.98),\n",
      "             ('gradient_steps', -1),\n",
      "             ('learning_rate', 0.001),\n",
      "             ('learning_starts', 1000),\n",
      "             ('n_episodes_rollout', 1),\n",
      "             ('n_timesteps', 1000000.0),\n",
      "             ('noise_std', 0.1),\n",
      "             ('noise_type', 'normal'),\n",
      "             ('policy', 'MlpPolicy'),\n",
      "             ('policy_kwargs', 'dict(net_arch=[400, 300])')])\n",
      "Using 1 environments\n",
      "Overwriting n_timesteps with n=100000\n",
      "Applying normal noise with std 0.1\n",
      "Using cuda device\n",
      "Log path: examples/logs/ddpg/mnist_trainer_2\n",
      "---------------------------------\n",
      "| rollout/           |          |\n",
      "|    ep_len_mean     | 4        |\n",
      "|    ep_rew_mean     | 1.67     |\n",
      "| time/              |          |\n",
      "|    episodes        | 100      |\n",
      "|    fps             | 38       |\n",
      "|    time_elapsed    | 10       |\n",
      "|    total timesteps | 400      |\n",
      "---------------------------------\n",
      "---------------------------------\n",
      "| rollout/           |          |\n",
      "|    ep_len_mean     | 4        |\n",
      "|    ep_rew_mean     | 2.18     |\n",
      "| time/              |          |\n",
      "|    episodes        | 200      |\n",
      "|    fps             | 37       |\n",
      "|    time_elapsed    | 21       |\n",
      "|    total timesteps | 800      |\n",
      "---------------------------------\n",
      "---------------------------------\n",
      "| rollout/           |          |\n",
      "|    ep_len_mean     | 4        |\n",
      "|    ep_rew_mean     | 1.43     |\n",
      "| time/              |          |\n",
      "|    episodes        | 300      |\n",
      "|    fps             | 4        |\n",
      "|    time_elapsed    | 258      |\n",
      "|    total timesteps | 1200     |\n",
      "| train/             |          |\n",
      "|    actor_loss      | -0.699   |\n",
      "|    critic_loss     | 0.0808   |\n",
      "|    learning_rate   | 0.001    |\n",
      "|    n_updates       | 196      |\n",
      "---------------------------------\n",
      "baseline: 0.0, ours: 0.8412776898734177, bits: 12.0 / 32\n",
      "Saving to examples/logs/ddpg/mnist_trainer_2\n"
     ]
    }
   ],
   "source": [
    "# Configuration for this run.\n",
    "neuron_wise = True\n",
    "sample_episode = False\n",
    "# Dataset location; set the MNIST_DATA_ROOT environment variable to override\n",
    "# the default path.\n",
    "data_root = os.environ.get(\"MNIST_DATA_ROOT\", \"/home/densechen/dataset\")\n",
    "\n",
    "# Define an MNIST classifier.\n",
    "mnist_classifier = MnistClassifier(mnist(neuron_wise))\n",
    "mnist_classifier.prepare_data(data_root=data_root)\n",
    "mnist_classifier.setup_train()  # use default configuration\n",
    "\n",
    "# Set the MNIST classifier to quantization mode.\n",
    "mnist_classifier.quantize()\n",
    "\n",
    "# Set neurons and states.\n",
    "# If neuron-wise, neurons is set to the maximum number of neurons in the\n",
    "# network (16 here); otherwise it is set to 1.\n",
    "mnist_classifier.set_neurons(16 if neuron_wise else 1)\n",
    "mnist_classifier.set_states(1)\n",
    "\n",
    "# None will use a default case.\n",
    "mnist_classifier.set_action_space(None)\n",
    "mnist_classifier.set_observation_space(None)\n",
    "\n",
    "# Attach the model to the trainer class.\n",
    "MnistTrainer.assign_model(mnist_classifier)\n",
    "\n",
    "# Define an experiment manager.\n",
    "exp_manager = eve.app.ExperimentManager(\n",
    "    algo=\"ddpg\",\n",
    "    env_id=\"mnist_trainer\",\n",
    "    env=MnistTrainer,\n",
    "    log_folder=\"examples/logs\",\n",
    "    n_timesteps=100000,\n",
    "    save_freq=1000,\n",
    "    default_hyperparameter_yaml=\"hyperparams\",\n",
    "    log_interval=100,\n",
    "    sample_episode=sample_episode,\n",
    ")\n",
    "\n",
    "model = exp_manager.setup_experiment()\n",
    "\n",
    "# Train, then save the trained model.\n",
    "exp_manager.learn(model)\n",
    "exp_manager.save_trained_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
